In [1]:
# TODO drop Kherson point
# TODO investigate Unnamed
In [2]:
import pandas as pd
import seaborn as sns
In [3]:
SEED = 14
In [4]:
data = pd.read_parquet("cleaned_data/data.parquet")
print(data.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 6010 entries, 0 to 6010
Data columns (total 12 columns):
 #   Column                  Non-Null Count  Dtype         
---  ------                  --------------  -----         
 0   hotspot_id              6010 non-null   int64         
 1   blacklist_score         6010 non-null   float64       
 2   static_score            6010 non-null   float64       
 3   dynamic_score           6010 non-null   float64       
 4   connection_stats_score  6010 non-null   float64       
 5   last_conn_date          6010 non-null   datetime64[ns]
 6   last_seen_date          6010 non-null   datetime64[ns]
 7   num_conn                6010 non-null   int64         
 8   unique_conn             6010 non-null   int64         
 9   percent_available       6010 non-null   float64       
 10  percent_protected       6010 non-null   float64       
 11  enabled_moderator       6010 non-null   bool          
dtypes: bool(1), datetime64[ns](2), float64(6), int64(3)
memory usage: 569.3 KB
None
In [ ]:
 

Analysis¶

In [ ]:
 
In [5]:
data
Out[5]:
hotspot_id blacklist_score static_score dynamic_score connection_stats_score last_conn_date last_seen_date num_conn unique_conn percent_available percent_protected enabled_moderator
0 14650480 0.0 0.22 0.45 0.69 2022-08-21 2021-05-01 5 2 1.0 1.0 True
1 14110275 0.0 0.22 0.00 0.67 2022-02-04 2022-02-04 4 2 1.0 1.0 True
2 16012785 0.0 0.18 0.16 0.67 2022-02-05 2022-02-15 4 2 1.0 1.0 True
3 14863945 0.0 0.22 0.05 0.72 2022-03-15 2021-06-12 6 2 1.0 0.8 True
4 9295867 0.0 0.39 0.00 0.52 2017-12-07 2017-12-07 1 1 1.0 0.0 True
... ... ... ... ... ... ... ... ... ... ... ... ...
6006 13213372 0.0 0.25 0.00 0.52 2020-04-27 2020-04-27 1 1 1.0 0.0 True
6007 5504114 0.0 0.22 0.00 0.52 2016-04-10 2022-01-02 1 1 1.0 0.0 True
6008 15109612 0.0 0.48 0.00 0.52 2021-07-20 2021-07-20 1 1 1.0 0.0 True
6009 502326 0.0 0.12 0.00 0.52 2013-11-19 2013-11-19 1 1 1.0 0.0 True
6010 14242378 0.0 0.22 0.00 0.52 2021-01-29 2021-01-29 1 1 1.0 0.0 True

6010 rows × 12 columns

In [7]:
data.nunique()
Out[7]:
hotspot_id                6003
blacklist_score              2
static_score                79
dynamic_score               90
connection_stats_score      22
last_conn_date            2243
last_seen_date            2059
num_conn                   615
unique_conn                 57
percent_available            2
percent_protected           57
enabled_moderator            1
dtype: int64
In [ ]:
 
In [8]:
# small_variaty_columns = hotspots_nunique[hotspots_nunique < 50].index

# for column in small_variaty_columns:
#     print(column)
#     print(hotspots[column].value_counts(dropna = False))
#     print()

Features Analysis¶

In [9]:
SORTED_QUALITY = ["spam", "bad", "moderate", "good"]
def calculate_quality(scores):
    def calculate_quality_for_row(row):
        blacklist_score, dynamic_score = row["blacklist_score"], row["dynamic_score"]
        if blacklist_score == 1:
            return "spam"
        if dynamic_score < 0.3:
            return "bad"
        elif dynamic_score >= 0.3 and dynamic_score < 0.6:
            return "moderate"

        return "good"
    quality = scores.apply(calculate_quality_for_row, axis = 1)\
        .astype("category").cat.reorder_categories(SORTED_QUALITY)
    return quality

def calculate_quality_code(scores):
    quality = calculate_quality(scores)
    return quality.cat.codes.rename("quality_cat_id")

print(calculate_quality(data).value_counts())
print()
bad         3787
good        1177
moderate     949
spam          97
dtype: int64

In [10]:
scores = data.copy()

scores["quality_cat_id"] = calculate_quality_code(scores)
print()
print(scores["quality_cat_id"].value_counts())
1    3787
3    1177
2     949
0      97
Name: quality_cat_id, dtype: int64
In [11]:
default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = {
    0: "black",
    1: "red",
    2: "blue",
    3: "green",
}, height = 2)
default_pairplot
Out[11]:
<seaborn.axisgrid.PairGrid at 0x7f98bb406a40>
In [53]:
sns.pairplot(scores, hue = "quality_cat_id", kind="hist", height=1.5)
Out[53]:
<seaborn.axisgrid.PairGrid at 0x7f0d6e7b9510>
In [52]:
sns.pairplot(scores, kind="hist", height=1.5)
Out[52]:
<seaborn.axisgrid.PairGrid at 0x7f0d6999da50>
In [ ]:
sns.pairplot(scores, hue = "quality_cat_id", palette = {
    0: "black",
    1: "red",
    2: "blue",
    3: "green",
}, height=1.5)

Model¶

In [12]:
import lightgbm as lgbm 
In [13]:
data
Out[13]:
hotspot_id blacklist_score static_score dynamic_score connection_stats_score last_conn_date last_seen_date num_conn unique_conn percent_available percent_protected enabled_moderator
0 14650480 0.0 0.22 0.45 0.69 2022-08-21 2021-05-01 5 2 1.0 1.0 True
1 14110275 0.0 0.22 0.00 0.67 2022-02-04 2022-02-04 4 2 1.0 1.0 True
2 16012785 0.0 0.18 0.16 0.67 2022-02-05 2022-02-15 4 2 1.0 1.0 True
3 14863945 0.0 0.22 0.05 0.72 2022-03-15 2021-06-12 6 2 1.0 0.8 True
4 9295867 0.0 0.39 0.00 0.52 2017-12-07 2017-12-07 1 1 1.0 0.0 True
... ... ... ... ... ... ... ... ... ... ... ... ...
6006 13213372 0.0 0.25 0.00 0.52 2020-04-27 2020-04-27 1 1 1.0 0.0 True
6007 5504114 0.0 0.22 0.00 0.52 2016-04-10 2022-01-02 1 1 1.0 0.0 True
6008 15109612 0.0 0.48 0.00 0.52 2021-07-20 2021-07-20 1 1 1.0 0.0 True
6009 502326 0.0 0.12 0.00 0.52 2013-11-19 2013-11-19 1 1 1.0 0.0 True
6010 14242378 0.0 0.22 0.00 0.52 2021-01-29 2021-01-29 1 1 1.0 0.0 True

6010 rows × 12 columns

In [14]:
y = calculate_quality_code(data)
print(y.info())
print()

insight_columns = ["blacklist_score", "hotspot_id", "dynamic_score"]
X = data.drop(columns = insight_columns)
print(X.info())
<class 'pandas.core.series.Series'>
Int64Index: 6010 entries, 0 to 6010
Series name: quality_cat_id
Non-Null Count  Dtype
--------------  -----
6010 non-null   int8 
dtypes: int8(1)
memory usage: 181.9 KB
None

<class 'pandas.core.frame.DataFrame'>
Int64Index: 6010 entries, 0 to 6010
Data columns (total 9 columns):
 #   Column                  Non-Null Count  Dtype         
---  ------                  --------------  -----         
 0   static_score            6010 non-null   float64       
 1   connection_stats_score  6010 non-null   float64       
 2   last_conn_date          6010 non-null   datetime64[ns]
 3   last_seen_date          6010 non-null   datetime64[ns]
 4   num_conn                6010 non-null   int64         
 5   unique_conn             6010 non-null   int64         
 6   percent_available       6010 non-null   float64       
 7   percent_protected       6010 non-null   float64       
 8   enabled_moderator       6010 non-null   bool          
dtypes: bool(1), datetime64[ns](2), float64(4), int64(2)
memory usage: 557.5 KB
None
In [16]:
# pd.concat([data.loc[y_train.index], y_train], axis = 1)
In [17]:
# X_y_pairplot = sns.pairplot(pd.concat([X, y], axis = 1), hue = y.name, palette = {
#     0: "black",
#     1: "red",
#     2: "blue",
#     3: "green",
# }, height = 2)
# X_y_pairplot
In [18]:
# X_y_pairplot_2 = sns.pairplot(pd.concat([X, y], axis = 1), height = 2)
# X_y_pairplot_2
In [19]:
import pandas as pd
import numpy as np
today = pd.Timestamp.now()

from lib import calculate_days_passed

X["last_conn_days"] = calculate_days_passed(X["last_conn_date"], today)
X["last_seen_days"] = calculate_days_passed(X["last_seen_date"], today)

X.drop(columns = ["last_conn_date", "last_seen_date"], inplace = True)
In [20]:
X.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 6010 entries, 0 to 6010
Data columns (total 9 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   static_score            6010 non-null   float64
 1   connection_stats_score  6010 non-null   float64
 2   num_conn                6010 non-null   int64  
 3   unique_conn             6010 non-null   int64  
 4   percent_available       6010 non-null   float64
 5   percent_protected       6010 non-null   float64
 6   enabled_moderator       6010 non-null   bool   
 7   last_conn_days          6010 non-null   int16  
 8   last_seen_days          6010 non-null   int16  
dtypes: bool(1), float64(4), int16(2), int64(2)
memory usage: 487.1 KB
In [21]:
# default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = {
#     0: "black",
#     1: "red",
#     2: "blue",
#     3: "green",
# }, height = 2)
# default_pairplot
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Input In [21], in <cell line: 1>()
----> 1 default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = {
      2     0: "black",
      3     1: "red",
      4     2: "blue",
      5     3: "green",
      6 }, height = 2)
      7 default_pairplot

File ~/.local/lib/python3.10/site-packages/seaborn/_decorators.py:46, in _deprecate_positional_args.<locals>.inner_f(*args, **kwargs)
     36     warnings.warn(
     37         "Pass the following variable{} as {}keyword arg{}: {}. "
     38         "From version 0.12, the only valid positional argument "
   (...)
     43         FutureWarning
     44     )
     45 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 46 return f(**kwargs)

File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:2140, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size)
   2138 if kind == "scatter":
   2139     from .relational import scatterplot  # Avoid circular import
-> 2140     plotter(scatterplot, **plot_kws)
   2141 elif kind == "reg":
   2142     from .regression import regplot  # Avoid circular import

File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1387, in PairGrid.map_offdiag(self, func, **kwargs)
   1376 """Plot with a bivariate function on the off-diagonal subplots.
   1377 
   1378 Parameters
   (...)
   1384 
   1385 """
   1386 if self.square_grid:
-> 1387     self.map_lower(func, **kwargs)
   1388     if not self._corner:
   1389         self.map_upper(func, **kwargs)

File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1357, in PairGrid.map_lower(self, func, **kwargs)
   1346 """Plot with a bivariate function on the lower diagonal subplots.
   1347 
   1348 Parameters
   (...)
   1354 
   1355 """
   1356 indices = zip(*np.tril_indices_from(self.axes, -1))
-> 1357 self._map_bivariate(func, indices, **kwargs)
   1358 return self

File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1539, in PairGrid._map_bivariate(self, func, indices, **kwargs)
   1537     if ax is None:  # i.e. we are in corner mode
   1538         continue
-> 1539     self._plot_bivariate(x_var, y_var, ax, func, **kws)
   1540 self._add_axis_labels()
   1542 if "hue" in signature(func).parameters:

File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1579, in PairGrid._plot_bivariate(self, x_var, y_var, ax, func, **kwargs)
   1577 kwargs.setdefault("hue_order", self._hue_order)
   1578 kwargs.setdefault("palette", self._orig_palette)
-> 1579 func(x=x, y=y, **kwargs)
   1581 self._update_legend_data(ax)

File ~/.local/lib/python3.10/site-packages/seaborn/_decorators.py:46, in _deprecate_positional_args.<locals>.inner_f(*args, **kwargs)
     36     warnings.warn(
     37         "Pass the following variable{} as {}keyword arg{}: {}. "
     38         "From version 0.12, the only valid positional argument "
   (...)
     43         FutureWarning
     44     )
     45 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 46 return f(**kwargs)

File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:827, in scatterplot(x, y, hue, style, size, data, palette, hue_order, hue_norm, sizes, size_order, size_norm, markers, style_order, x_bins, y_bins, units, estimator, ci, n_boot, alpha, x_jitter, y_jitter, legend, ax, **kwargs)
    823     return ax
    825 p._attach(ax)
--> 827 p.plot(ax, kwargs)
    829 return ax

File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:670, in _ScatterPlotter.plot(self, ax, kws)
    668 self._add_axis_labels(ax)
    669 if self.legend:
--> 670     self.add_legend_data(ax)
    671     handles, _ = ax.get_legend_handles_labels()
    672     if handles:

File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:337, in _RelationalPlotter.add_legend_data(self, ax)
    335     if attr in kws:
    336         use_kws[attr] = kws[attr]
--> 337 artist = func([], [], label=label, **use_kws)
    338 if self._legend_func == "plot":
    339     artist = artist[0]

File ~/.local/lib/python3.10/site-packages/matplotlib/__init__.py:1601, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1598 @functools.wraps(func)
   1599 def inner(ax, *args, data=None, **kwargs):
   1600     if data is None:
-> 1601         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1603     bound = new_sig.bind(ax, *args, **kwargs)
   1604     needs_label = (label_namer
   1605                    and "label" not in bound.arguments
   1606                    and "label" not in bound.kwargs)

File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_axes.py:4528, in Axes.scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, plotnonfinite, **kwargs)
   4525         self.set_ymargin(0.05)
   4527 self.add_collection(collection)
-> 4528 self.autoscale_view()
   4530 return collection

File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_base.py:2496, in _AxesBase.autoscale_view(self, tight, scalex, scaley)
   2491     # End of definition of internal function 'handle_single_axis'.
   2493 handle_single_axis(
   2494     scalex, self._autoscaleXon, self._shared_x_axes, 'intervalx',
   2495     'minposx', self.xaxis, self._xmargin, x_stickies, self.set_xbound)
-> 2496 handle_single_axis(
   2497     scaley, self._autoscaleYon, self._shared_y_axes, 'intervaly',
   2498     'minposy', self.yaxis, self._ymargin, y_stickies, self.set_ybound)

File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_base.py:2449, in _AxesBase.autoscale_view.<locals>.handle_single_axis(scale, autoscaleon, shared_axes, interval, minpos, axis, margin, stickies, set_bound)
   2446     dl.extend(x_finite)
   2447     dl.extend(y_finite)
-> 2449 bb = mtransforms.BboxBase.union(dl)
   2450 # fall back on the viewlimits if this is not finite:
   2451 vl = None

File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:703, in BboxBase.union(bboxes)
    701     raise ValueError("'bboxes' cannot be empty")
    702 x0 = np.min([bbox.xmin for bbox in bboxes])
--> 703 x1 = np.max([bbox.xmax for bbox in bboxes])
    704 y0 = np.min([bbox.ymin for bbox in bboxes])
    705 y1 = np.max([bbox.ymax for bbox in bboxes])

File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:703, in <listcomp>(.0)
    701     raise ValueError("'bboxes' cannot be empty")
    702 x0 = np.min([bbox.xmin for bbox in bboxes])
--> 703 x1 = np.max([bbox.xmax for bbox in bboxes])
    704 y0 = np.min([bbox.ymin for bbox in bboxes])
    705 y1 = np.max([bbox.ymax for bbox in bboxes])

File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:360, in BboxBase.xmax(self)
    357 @property
    358 def xmax(self):
    359     """The right edge of the bounding box."""
--> 360     return np.max(self.get_points()[:, 0])

File <__array_function__ internals>:180, in amax(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2791, in amax(a, axis, out, keepdims, initial, where)
   2675 @array_function_dispatch(_amax_dispatcher)
   2676 def amax(a, axis=None, out=None, keepdims=np._NoValue, initial=np._NoValue,
   2677          where=np._NoValue):
   2678     """
   2679     Return the maximum of an array or maximum along an axis.
   2680 
   (...)
   2789     5
   2790     """
-> 2791     return _wrapreduction(a, np.maximum, 'max', axis, None, out,
   2792                           keepdims=keepdims, initial=initial, where=where)

File ~/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:73, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
     69 def _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs):
     70     passkwargs = {k: v for k, v in kwargs.items()
     71                   if v is not np._NoValue}
---> 73     if type(obj) is not mu.ndarray:
     74         try:
     75             reduction = getattr(obj, method)

KeyboardInterrupt: 
Error in callback <function flush_figures at 0x7f98bdeb4820> (for post_execute):
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:121, in flush_figures()
    118 if InlineBackend.instance().close_figures:
    119     # ignore the tracking, just draw and close all figures
    120     try:
--> 121         return show(True)
    122     except Exception as e:
    123         # safely show traceback if in IPython, else raise
    124         ip = get_ipython()

File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:43, in show(close, block)
     39 try:
     40     for figure_manager in Gcf.get_all_fig_managers():
     41         display(
     42             figure_manager.canvas.figure,
---> 43             metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
     44         )
     45 finally:
     46     show._to_draw = []

File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:231, in _fetch_figure_metadata(fig)
    228 # determine if a background is needed for legibility
    229 if _is_transparent(fig.get_facecolor()):
    230     # the background is transparent
--> 231     ticksLight = _is_light([label.get_color()
    232                             for axes in fig.axes
    233                             for axis in (axes.xaxis, axes.yaxis)
    234                             for label in axis.get_ticklabels()])
    235     if ticksLight.size and (ticksLight == ticksLight[0]).all():
    236         # there are one or more tick labels, all with the same lightness
    237         return {'needs_background': 'dark' if ticksLight[0] else 'light'}

File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:234, in <listcomp>(.0)
    228 # determine if a background is needed for legibility
    229 if _is_transparent(fig.get_facecolor()):
    230     # the background is transparent
    231     ticksLight = _is_light([label.get_color()
    232                             for axes in fig.axes
    233                             for axis in (axes.xaxis, axes.yaxis)
--> 234                             for label in axis.get_ticklabels()])
    235     if ticksLight.size and (ticksLight == ticksLight[0]).all():
    236         # there are one or more tick labels, all with the same lightness
    237         return {'needs_background': 'dark' if ticksLight[0] else 'light'}

File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1296, in Axis.get_ticklabels(self, minor, which)
   1294 if minor:
   1295     return self.get_minorticklabels()
-> 1296 return self.get_majorticklabels()

File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1252, in Axis.get_majorticklabels(self)
   1250 def get_majorticklabels(self):
   1251     'Return a list of Text instances for the major ticklabels.'
-> 1252     ticks = self.get_major_ticks()
   1253     labels1 = [tick.label1 for tick in ticks if tick.label1.get_visible()]
   1254     labels2 = [tick.label2 for tick in ticks if tick.label2.get_visible()]

File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1407, in Axis.get_major_ticks(self, numticks)
   1405 'Get the tick instances; grow as necessary.'
   1406 if numticks is None:
-> 1407     numticks = len(self.get_majorticklocs())
   1409 while len(self.majorTicks) < numticks:
   1410     # Update the new tick label properties from the old.
   1411     tick = self._get_tick(major=True)

File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1324, in Axis.get_majorticklocs(self)
   1322 def get_majorticklocs(self):
   1323     """Get the array of major tick locations in data coordinates."""
-> 1324     return self.major.locator()

File ~/.local/lib/python3.10/site-packages/matplotlib/ticker.py:2078, in MaxNLocator.__call__(self)
   2076 def __call__(self):
   2077     vmin, vmax = self.axis.get_view_interval()
-> 2078     return self.tick_values(vmin, vmax)

File ~/.local/lib/python3.10/site-packages/matplotlib/ticker.py:2084, in MaxNLocator.tick_values(self, vmin, vmax)
   2082     vmax = max(abs(vmin), abs(vmax))
   2083     vmin = -vmax
-> 2084 vmin, vmax = mtransforms.nonsingular(
   2085     vmin, vmax, expander=1e-13, tiny=1e-14)
   2086 locs = self._raw_ticks(vmin, vmax)
   2088 prune = self._prune

File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:2828, in nonsingular(vmin, vmax, expander, tiny, increasing)
   2825     swapped = True
   2827 maxabsvalue = max(abs(vmin), abs(vmax))
-> 2828 if maxabsvalue < (1e6 / tiny) * np.finfo(float).tiny:
   2829     vmin = -expander
   2830     vmax = expander

File ~/.local/lib/python3.10/site-packages/numpy/core/getlimits.py:577, in finfo.tiny(self)
    562 @property
    563 def tiny(self):
    564     """Return the value for tiny, alias of smallest_normal.
    565 
    566     Returns
   (...)
    575         double-double.
    576     """
--> 577     return self.smallest_normal

File ~/.local/lib/python3.10/site-packages/numpy/core/getlimits.py:556, in finfo.smallest_normal(self)
    541 """Return the value for the smallest normal.
    542 
    543 Returns
   (...)
    552     double-double.
    553 """
    554 # This check is necessary because the value for smallest_normal is
    555 # platform dependent for longdouble types.
--> 556 if isnan(self._machar.smallest_normal.flat[0]):
    557     warnings.warn(
    558         'The value of smallest normal is undefined for double double',
    559         UserWarning, stacklevel=2)
    560 return self._machar.smallest_normal.flat[0]

KeyboardInterrupt: 

Train¶

In [22]:
X_for_exp = X #.drop(columns = ["connection_stats_score", "percent_protected", "last_conn_days", "static_score"])
X_for_exp = X[["connection_stats_score", "last_seen_days", "last_conn_days"]]
y_for_exp = y.loc[X.index]
In [24]:
X_for_exp["last_seen_days"].hist()
Out[24]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f98100540a0>
In [25]:
from typing import List
import plotly.graph_objs as go
from sklearn.metrics import confusion_matrix

def confusion_matrix_plot(y_test: List[str], y_pred: List[str], labels: List[str], display_labels: List[str] = None, normalise: bool = False) -> go.Figure:
    # Compute the confusion matrix
    cm = confusion_matrix(y_test, y_pred, labels=labels)

    # Normalize the matrix by rows
    if normalise:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    # Define the display labels
    if display_labels is None:
        display_labels = labels

    # Define the data for the heatmap
    data = go.Heatmap(
        z=cm,
        x=display_labels,
        y=display_labels,
        colorscale='YlGnBu'
    )

    # Define the layout of the plot
    layout = go.Layout(
        title='Confusion Matrix',
        xaxis=dict(title='Predicted label'),
        yaxis=dict(title='True label')
    )

    # Create the plot
    fig = go.Figure(data=[data], layout=layout)
    return fig
In [ ]:
 
In [26]:
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Load your data and split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_for_exp, y_for_exp, test_size=0.2, random_state=SEED)



# Create an instance of the LGBMClassifier
params = {
    'num_classes': 4,
    'objective': 'multiclass',
    'random_state': SEED,
    'n_estimators': 100,
    'max_depth': 10,
    'learning_rate': 0.1
}

lgbm = LGBMClassifier(**params)

# Fit the model on the training data
lgbm.fit(X_train, y_train)

# Predict the class labels of the testing data
y_pred = lgbm.predict(X_test)
print("Train classes: ")
print(y_train.value_counts())
print("Test classes: ")
print(y_test.value_counts())
print("Test predicts: ")
print( pd.DataFrame(y_pred).value_counts() )
# Print the classification report
print(classification_report(y_test, y_pred))

correlations = pd.concat([X_for_exp, y_for_exp], axis=1).corr()

# Print the correlation between each feature and the target variable
print("Correlation to quality")
print(correlations[y.name].iloc[:-1])

cm = confusion_matrix_plot(y_test, y_pred, lgbm.classes_, display_labels = SORTED_QUALITY, normalise = True)
cm.show()

cm = confusion_matrix_plot(y_test, y_pred, lgbm.classes_, display_labels = SORTED_QUALITY, normalise = False)
cm.show()
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000113 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 529
[LightGBM] [Info] Number of data points in the train set: 4808, number of used features: 3
[LightGBM] [Info] Start training from score -4.096010
[LightGBM] [Info] Start training from score -0.461389
[LightGBM] [Info] Start training from score -1.838161
[LightGBM] [Info] Start training from score -1.640704
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
Train classes: 
1    3031
3     932
2     765
0      80
Name: quality_cat_id, dtype: int64
Test classes: 
1    756
3    245
2    184
0     17
Name: quality_cat_id, dtype: int64
Test predicts: 
1    766
3    225
2    209
0      2
dtype: int64
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        17
           1       0.96      0.97      0.96       756
           2       0.76      0.86      0.81       184
           3       0.96      0.88      0.92       245

    accuracy                           0.92      1202
   macro avg       0.67      0.68      0.67      1202
weighted avg       0.91      0.92      0.92      1202

Correlation to quality
connection_stats_score    0.859035
last_seen_days           -0.585585
last_conn_days           -0.620113
Name: quality_cat_id, dtype: float64
In [ ]:
 
In [27]:
# Get feature importances and column names
importances = lgbm.feature_importances_
features = X_for_exp.columns

# Create a list of tuples of feature names and importances, sorted by importance
feature_importances = [(feature, importance) for feature, importance in zip(features, importances)]
feature_importances = sorted(feature_importances, key=lambda x: x[1], reverse=True)

# Print the sorted list of features and their importances
for feature, importance in feature_importances:
    print('{}: {}'.format(feature, importance))
last_seen_days: 5287
last_conn_days: 5155
connection_stats_score: 1540
In [49]:
import pickle
# Save the model to a file using pickle
with open('lgbm_model.pkl', 'wb') as f:
    pickle.dump(lgbm, f)
In [ ]: